{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5a01f3ac-5306-4a1b-9e47-a5d254bce93a",
   "metadata": {},
   "source": [
    "# Understanding Validators\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dcc78ac-ed6d-49e3-b71b-fb2fb25f16a8",
   "metadata": {},
   "source": [
    "Pydantic offers an customizable and expressive validation framework for Python. Instructor leverages Pydantic's validation framework to provide a uniform developer experience for both code-based and LLM-based validation, as well as a reasking mechanism for correcting LLM outputs based on validation errors. To learn more check out the Pydantic [docs](https://docs.pydantic.dev/latest/) on validators.\n",
    "\n",
    "Then we'll bring it all together into the context of RAG from the previous notebook.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "064c286b",
   "metadata": {},
   "source": [
    "Validators will enable us to control outputs by defining a function like so:\n",
    "\n",
    "```python\n",
    "def validation_function(value):\n",
    "    if condition(value):\n",
    "        raise ValueError(\"Value is not valid\")\n",
    "    return mutation(value)\n",
    "```\n",
    "\n",
    "Before we get started lets go over the general shape of a validator:\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cfc6c66",
   "metadata": {},
   "source": [
    "## Defining Validator Functions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d4bb6258-b03a-4621-8a73-29056a20ec0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing_extensions import Annotated\n",
    "from pydantic import BaseModel, AfterValidator, WithJsonSchema\n",
    "\n",
    "\n",
    "def name_must_contain_space(v: str) -> str:\n",
    "    if \" \" not in v:\n",
    "        raise ValueError(\"Name must contain a space.\")\n",
    "    return v\n",
    "\n",
    "def uppercase_name(v: str) -> str:\n",
    "    return v.upper()\n",
    "\n",
    "FullName = Annotated[\n",
    "    str, \n",
    "    AfterValidator(name_must_contain_space), \n",
    "    AfterValidator(uppercase_name),\n",
    "    WithJsonSchema(\n",
    "        {\n",
    "            \"type\": \"string\",\n",
    "            \"description\": \"The user's full name\",\n",
    "        }\n",
    "    )]\n",
    "\n",
    "class UserDetail(BaseModel):\n",
    "    age: int\n",
    "    name: FullName"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "23f8cadd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "UserDetail(age=30, name='JASON LIU')"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "UserDetail(age=30, name=\"Jason Liu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "e4f53ecf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'properties': {'age': {'title': 'Age', 'type': 'integer'},\n",
       "  'name': {'description': \"The user's full name\",\n",
       "   'title': 'Name',\n",
       "   'type': 'string'}},\n",
       " 'required': ['age', 'name'],\n",
       " 'title': 'UserDetail',\n",
       " 'type': 'object'}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "UserDetail.model_json_schema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "2284a7e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 validation error for UserDetail\n",
      "name\n",
      "  Value error, Name must contain a space. [type=value_error, input_value='Jason', input_type=str]\n",
      "    For further information visit https://errors.pydantic.dev/2.5/v/value_error\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    person = UserDetail.model_validate({\"age\": 24, \"name\": \"Jason\"})\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c0302ca",
   "metadata": {},
   "source": [
    "## Using Field\n",
    "\n",
    "We can also use the `Field` class to define validators. This is useful when we want to define a validator for a field that is primative, like a string or integer which supports a limited number of validators.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3242856f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 validation errors for UserDetail\n",
      "age\n",
      "  Input should be greater than 0 [type=greater_than, input_value=-10, input_type=int]\n",
      "    For further information visit https://errors.pydantic.dev/2.5/v/greater_than\n",
      "name\n",
      "  Value error, Name must contain a space. [type=value_error, input_value='Jason', input_type=str]\n",
      "    For further information visit https://errors.pydantic.dev/2.5/v/value_error\n"
     ]
    }
   ],
   "source": [
    "from pydantic import Field\n",
    "\n",
    "\n",
    "Age = Annotated[int, Field(gt=0)]\n",
    "\n",
    "class UserDetail(BaseModel):\n",
    "    age: Age\n",
    "    name: FullName\n",
    "\n",
    "try:\n",
    "    person = UserDetail(age=-10, name=\"Jason\")\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f689121",
   "metadata": {},
   "source": [
    "## Providing Context\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ec043c23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 validation error for Response\n",
      "message\n",
      "  Assertion failed, `hurt` was found in the message `I will hurt them.` [type=assertion_error, input_value='I will hurt them.', input_type=str]\n",
      "    For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n"
     ]
    }
   ],
   "source": [
    "from pydantic import ValidationInfo\n",
    "\n",
    "\n",
    "def message_cannot_have_blacklisted_words(v: str, info: ValidationInfo) -> str:\n",
    "    blacklist = info.context.get(\"blacklist\", [])\n",
    "    for word in blacklist:\n",
    "        assert word not in v.lower(), f\"`{word}` was found in the message `{v}`\"\n",
    "    return v\n",
    "\n",
    "ModeratedStr = Annotated[str, AfterValidator(message_cannot_have_blacklisted_words)]\n",
    "\n",
    "class Response(BaseModel):\n",
    "    message: ModeratedStr\n",
    "\n",
    "\n",
    "try:\n",
    "    Response.model_validate(\n",
    "        {\"message\": \"I will hurt them.\"},\n",
    "        context={\n",
    "            \"blacklist\": {\n",
    "                \"rob\",\n",
    "                \"steal\",\n",
    "                \"hurt\",\n",
    "                \"kill\",\n",
    "                \"attack\",\n",
    "            }\n",
    "        },\n",
    "    )\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37e3a638-c9c9-44cd-bcd0-ad1a39f448db",
   "metadata": {},
   "source": [
    "## Using OpenAI Moderation\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88d0b816-7ec8-42b0-9b91-c9aab382c960",
   "metadata": {},
   "source": [
    "To enhance our validation measures, we'll extend the scope to flag any answer that contains hateful content, harassment, or similar issues. OpenAI offers a moderation endpoint that addresses these concerns, and it's freely available when using OpenAI models.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65f46eb5",
   "metadata": {},
   "source": [
    "With the `instructor` library, this is just one function edit away:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "82521112-5301-4442-acce-82b495bd838f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 validation error for Response\n",
      "message\n",
      "  Value error, `I want to make them suffer the consequences` was flagged for harassment, harassment_threatening, violence, harassment/threatening [type=value_error, input_value='I want to make them suffer the consequences', input_type=str]\n",
      "    For further information visit https://errors.pydantic.dev/2.5/v/value_error\n"
     ]
    }
   ],
   "source": [
    "from typing import Annotated\n",
    "from pydantic import AfterValidator\n",
    "from instructor import openai_moderation\n",
    "\n",
    "import instructor\n",
    "from openai import OpenAI\n",
    "\n",
    "client = instructor.patch(OpenAI())\n",
    "\n",
    "# This uses Annotated which is a new feature in Python 3.9\n",
    "# To define custom metadata for a type hint.\n",
    "ModeratedStr = Annotated[str, AfterValidator(openai_moderation(client=client))]\n",
    "\n",
    "\n",
    "class Response(BaseModel):\n",
    "    message: ModeratedStr\n",
    "\n",
    "\n",
    "try:\n",
    "    Response(message=\"I want to make them suffer the consequences\")\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "faa5116e",
   "metadata": {},
   "source": [
    "## General Validator\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49d8b772",
   "metadata": {},
   "outputs": [],
   "source": [
    "from instructor import llm_validator\n",
    "\n",
    "HealthTopicStr = Annotated[\n",
    "    str,\n",
    "    AfterValidator(\n",
    "        llm_validator(\n",
    "            \"don't talk about any other topic except health best practices and topics\",\n",
    "            openai_client=client,\n",
    "        )\n",
    "    ),\n",
    "]\n",
    "\n",
    "\n",
    "class AssistantMessage(BaseModel):\n",
    "    message: HealthTopicStr\n",
    "\n",
    "\n",
    "AssistantMessage(\n",
    "    message=\"I would suggest you to visit Sicily as they say it is very nice in winter.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "050e72fe-4b13-4002-a1d0-94f7b88b784b",
   "metadata": {},
   "source": [
    "### Avoiding hallucination with citations\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3f2869e-c8a3-4b93-82e7-55eb70930900",
   "metadata": {},
   "source": [
    "When incorporating external knowledge bases, it's crucial to ensure that the agent uses the provided context accurately and doesn't fabricate responses. Validators can be effectively used for this purpose. We can illustrate this with an example where we validate that a provided citation is actually included in the referenced text chunk:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "638fc368-5cf7-4ae7-9d3f-efea1b84eec0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 validation error for AnswerWithCitation\n",
      "citation\n",
      "  Value error, Citation `Blueberries contain high levels of protein` not found in text, only use citations from the text. [type=value_error, input_value='Blueberries contain high levels of protein', input_type=str]\n",
      "    For further information visit https://errors.pydantic.dev/2.5/v/value_error\n"
     ]
    }
   ],
   "source": [
    "from pydantic import ValidationInfo\n",
    "\n",
    "def citation_exists(v: str, info: ValidationInfo):\n",
    "    context = info.context\n",
    "    if context:\n",
    "        context = context.get(\"text_chunk\")\n",
    "        if v not in context:\n",
    "            raise ValueError(f\"Citation `{v}` not found in text, only use citations from the text.\")\n",
    "    return v\n",
    "\n",
    "Citation = Annotated[str, AfterValidator(citation_exists)]\n",
    "\n",
    "\n",
    "class AnswerWithCitation(BaseModel):\n",
    "    answer: str\n",
    "    citation: Citation\n",
    "\n",
    "try:\n",
    "    AnswerWithCitation.model_validate(\n",
    "        {\n",
    "            \"answer\": \"Blueberries are packed with protein\",\n",
    "            \"citation\": \"Blueberries contain high levels of protein\",\n",
    "        },\n",
    "        context={\"text_chunk\": \"Blueberries are very rich in antioxidants\"},\n",
    "    )\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3064b06b-7f85-40ec-8fe2-4fa2cce36585",
   "metadata": {},
   "source": [
    "Here we assume that there is a \"text_chunk\" field that contains the text that the model is supposed to use as context. We then use the `field_validator` decorator to define a validator that checks if the citation is included in the text chunk. If it's not, we raise a `ValueError` with a message that will be returned to the user.\n",
    "\n",
    "\n",
    "If we want to pass in the context through the `chat.completions.create`` endpoint, we can use the `validation_context` parameter\n",
    "\n",
    "```python\n",
    "resp = client.chat.completions.create(\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    response_model=AnswerWithCitation,\n",
    "    messages=[\n",
    "        {\"role\": \"user\", \"content\": f\"Answer the question `{q}` using the text chunk\\n`{text_chunk}`\"},\n",
    "    ],\n",
    "    validation_context={\"text_chunk\": text_chunk},\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64d15ad2",
   "metadata": {},
   "source": [
    "In practice there are many ways to implement this: we could use a regex to check if the citation is included in the text chunk, or we could use a more sophisticated approach like a semantic similarity check. The important thing is that we have a way to validate that the model is using the provided context accurately.\n"
   ]
  },
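  {
   "cell_type": "markdown",
   "id": "c1a7e5d0-lenient-citation-sketch",
   "metadata": {},
   "source": [
    "As a minimal sketch of one such variation (an assumption for illustration, not part of Instructor), the validator below normalizes whitespace and case before checking containment, so a citation still matches when the model reflows the text. The names `normalize`, `citation_exists_normalized`, `LenientCitation`, and `AnswerWithLenientCitation` are hypothetical; a regex or a semantic similarity check could be swapped in the same way.\n",
    "\n",
    "```python\n",
    "from typing import Annotated\n",
    "\n",
    "from pydantic import AfterValidator, BaseModel, ValidationError, ValidationInfo\n",
    "\n",
    "\n",
    "def normalize(text: str) -> str:\n",
    "    # Collapse runs of whitespace and ignore case before comparing.\n",
    "    return ' '.join(text.split()).casefold()\n",
    "\n",
    "\n",
    "def citation_exists_normalized(v: str, info: ValidationInfo) -> str:\n",
    "    # `text_chunk` is the same context key used in the example above.\n",
    "    text_chunk = (info.context or {}).get('text_chunk')\n",
    "    if text_chunk is not None and normalize(v) not in normalize(text_chunk):\n",
    "        raise ValueError(f'Citation `{v}` not found in text, only use citations from the text.')\n",
    "    return v\n",
    "\n",
    "\n",
    "LenientCitation = Annotated[str, AfterValidator(citation_exists_normalized)]\n",
    "\n",
    "\n",
    "class AnswerWithLenientCitation(BaseModel):\n",
    "    answer: str\n",
    "    citation: LenientCitation\n",
    "\n",
    "\n",
    "# The chunk contains a double space that the citation does not reproduce.\n",
    "chunk = 'Blueberries are  very rich in antioxidants'\n",
    "ok = AnswerWithLenientCitation.model_validate(\n",
    "    {'answer': 'They are rich in antioxidants', 'citation': 'very rich in Antioxidants'},\n",
    "    context={'text_chunk': chunk},\n",
    ")\n",
    "```\n",
    "\n",
    "With this validator, the citation `very rich in Antioxidants` passes despite the extra space and different capitalization in the chunk, while a citation that does not appear in the chunk still raises a validation error.\n"
   ]
  },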
  {
   "cell_type": "markdown",
   "id": "5bbbaa11-32d2-4772-bc31-18d1d6d6c919",
   "metadata": {},
   "source": [
    "## Reasking with validators\n",
    "\n",
    "For most of these examples all we've done we've mostly only defined the validation logic. Which can be seperate from generation, however when we are given validation errors, we shouldn't end there! Instead instructor allows us to collect all the validation errors and reask the llm to rewrite their answer.\n",
    "\n",
    "Lets try to use a extreme example to illustrate this point:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "97f544e7-2552-465c-89a9-a4820f00d658",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"question\": \"What is the meaning of life?\",\n",
      "  \"answer\": \"According to the devil, the meaning of life is a life of sin and debauchery.\"\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "class QuestionAnswer(BaseModel):\n",
    "    question: str\n",
    "    answer: str\n",
    "\n",
    "\n",
    "question = \"What is the meaning of life?\"\n",
    "context = (\n",
    "    \"The according to the devil the meaning of life is a life of sin and debauchery.\"\n",
    ")\n",
    "\n",
    "\n",
    "resp = client.chat.completions.create(\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    response_model=QuestionAnswer,\n",
    "    messages=[\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n",
    "        },\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": f\"using the context: `{context}`\\n\\nAnswer the following question: `{question}`\",\n",
    "        },\n",
    "    ],\n",
    ")\n",
    "\n",
    "print(resp.model_dump_json(indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0328bbc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Retrying, exception: 1 validation error for QuestionAnswer\n",
      "answer\n",
      "  Assertion failed, The statement promotes sin and debauchery, which can be considered objectionable. [type=assertion_error, input_value='The meaning of life, acc... of sin and debauchery.', input_type=str]\n",
      "    For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/jasonliu/dev/instructor/instructor/patch.py\", line 277, in retry_sync\n",
      "    return process_response(\n",
      "           ^^^^^^^^^^^^^^^^^\n",
      "  File \"/Users/jasonliu/dev/instructor/instructor/patch.py\", line 164, in process_response\n",
      "    model = response_model.from_response(\n",
      "            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/Users/jasonliu/dev/instructor/instructor/function_calls.py\", line 137, in from_response\n",
      "    return cls.model_validate_json(\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/Users/jasonliu/dev/instructor/.venv/lib/python3.11/site-packages/pydantic/main.py\", line 532, in model_validate_json\n",
      "    return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "pydantic_core._pydantic_core.ValidationError: 1 validation error for QuestionAnswer\n",
      "answer\n",
      "  Assertion failed, The statement promotes sin and debauchery, which can be considered objectionable. [type=assertion_error, input_value='The meaning of life, acc... of sin and debauchery.', input_type=str]\n",
      "    For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n"
     ]
    }
   ],
   "source": [
    "from instructor import llm_validator\n",
    "\n",
    "\n",
    "NotEvilAnswer = Annotated[\n",
    "    str,\n",
    "    AfterValidator(\n",
    "        llm_validator(\"don't say objectionable things\", openai_client=client)\n",
    "    ),\n",
    "]\n",
    "\n",
    "\n",
    "class QuestionAnswer(BaseModel):\n",
    "    question: str\n",
    "    answer: NotEvilAnswer\n",
    "\n",
    "\n",
    "resp = client.chat.completions.create(\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    response_model=QuestionAnswer,\n",
    "    max_retries=2,\n",
    "    messages=[\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n",
    "        },\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": f\"using the context: `{context}`\\n\\nAnswer the following question: `{question}`\",\n",
    "        },\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "814d3554",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"question\": \"What is the meaning of life?\",\n",
      "  \"answer\": \"The meaning of life is subjective and can vary depending on one's beliefs and perspectives. According to the devil, it is a life of sin and debauchery. However, this viewpoint may not be universally accepted and should be evaluated critically.\"\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(resp.model_dump_json(indent=2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
